Skip to content

Commit 5bb16b0

Browse files
Fix act config exporting for mixed schemes (#903)
* fp8 exporting bugfix Signed-off-by: Zhang, Weiwei1 <[email protected]> * fix act related config saving Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add ut for act_config check Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine extra_config saving, add UTs Signed-off-by: Zhang, Weiwei1 <[email protected]> * fix ut typo Signed-off-by: Zhang, Weiwei1 <[email protected]> * fix ut typo Signed-off-by: Zhang, Weiwei1 <[email protected]> * fixtypo Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Zhang, Weiwei1 <[email protected]> * fix scan issue Signed-off-by: Zhang, Weiwei1 <[email protected]> * fix scan issue Signed-off-by: Zhang, Weiwei1 <[email protected]> * rm global variable Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rerun ut Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine ut Signed-off-by: Zhang, Weiwei1 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Zhang, Weiwei1 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0164077 commit 5bb16b0

File tree

5 files changed

+207
-70
lines changed

5 files changed

+207
-70
lines changed

auto_round/export/export_to_autoround/export.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
from concurrent.futures import ThreadPoolExecutor
21+
from dataclasses import fields
2122
from enum import Enum
2223

2324
import threadpoolctl as tctl
@@ -26,9 +27,10 @@
2627
import transformers
2728
from tqdm import tqdm
2829

29-
from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
30+
from auto_round.export.export_to_autoround.utils import check_neq_config
3031
from auto_round.export.utils import save_model
3132
from auto_round.logger import logger
33+
from auto_round.schemes import QuantizationScheme
3234
from auto_round.utils import (
3335
SUPPORTED_FORMATS,
3436
SUPPORTED_LAYER_TYPES,
@@ -324,26 +326,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
324326
for i in range(len(block_name_to_quantize)):
325327
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
326328

327-
for layer_name in layer_config:
328-
if (
329-
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
330-
): ##lm head ##TODO fix act and so on
331-
extra_config[layer_name] = {}
332-
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
333-
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
334-
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
335-
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
336-
elif layer_config[layer_name]["in_blocks"] or (
329+
scheme_keys = [f.name for f in fields(QuantizationScheme)]
330+
for layer_name, cfg in layer_config.items():
331+
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
332+
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
333+
elif cfg["in_blocks"] or (
337334
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
338335
):
339-
neq_keys = check_neq_config(
340-
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
341-
)
336+
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
342337
if len(neq_keys) > 0:
343338
extra_config[layer_name] = {}
344-
for key in neq_keys:
345-
if layer_config[layer_name][key] is not None:
346-
extra_config[layer_name][key] = layer_config[layer_name][key]
339+
for key in scheme_keys:
340+
if cfg[key] is not None:
341+
extra_config[layer_name][key] = cfg[key]
342+
347343
if len(extra_config) > 0:
348344
quantization_config["extra_config"] = extra_config
349345
names = list(layer_config.keys())

auto_round/export/export_to_autoround/export_to_fp8.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@
1616
import json
1717
import os
1818
from concurrent.futures import ThreadPoolExecutor
19+
from dataclasses import fields
1920

2021
import threadpoolctl as tctl
2122
import torch
2223
import transformers
2324
from tqdm import tqdm
2425

2526
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
26-
from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
27+
from auto_round.export.export_to_autoround.utils import check_neq_config
2728
from auto_round.export.utils import save_model
2829
from auto_round.logger import logger
30+
from auto_round.schemes import QuantizationScheme
2931
from auto_round.utils import (
3032
SUPPORTED_LAYER_TYPES,
3133
_get_packing_device,
@@ -169,26 +171,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round",
169171
for i in range(len(block_name_to_quantize)):
170172
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
171173

172-
for layer_name in layer_config:
173-
if (
174-
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
175-
): ##lm head ##TODO fix act and so on
176-
extra_config[layer_name] = {}
177-
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
178-
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
179-
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
180-
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
181-
elif layer_config[layer_name]["in_blocks"] or (
174+
scheme_keys = [f.name for f in fields(QuantizationScheme)]
175+
for layer_name, cfg in layer_config.items():
176+
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
177+
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
178+
elif cfg["in_blocks"] or (
182179
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
183180
):
184-
neq_keys = check_neq_config(
185-
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
186-
)
181+
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
187182
if len(neq_keys) > 0:
188183
extra_config[layer_name] = {}
189-
for key in neq_keys:
190-
if layer_config[layer_name][key] is not None:
191-
extra_config[layer_name][key] = layer_config[layer_name][key]
184+
for key in scheme_keys:
185+
if cfg[key] is not None:
186+
extra_config[layer_name][key] = cfg[key]
187+
192188
if len(extra_config) > 0:
193189
quantization_config["extra_config"] = extra_config
194190
names = list(layer_config.keys())

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@
1717
import json
1818
import os
1919
from concurrent.futures import ThreadPoolExecutor
20+
from dataclasses import fields
2021

2122
import threadpoolctl as tctl
2223
import torch
2324
import torch.nn as nn
2425
import transformers
2526
from tqdm import tqdm
2627

27-
from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
28+
from auto_round.export.export_to_autoround.utils import check_neq_config
2829
from auto_round.export.utils import save_model
2930
from auto_round.logger import logger
31+
from auto_round.schemes import QuantizationScheme
3032
from auto_round.utils import (
3133
SUPPORTED_LAYER_TYPES,
3234
_get_packing_device,
@@ -195,26 +197,20 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
195197
for i in range(len(block_name_to_quantize)):
196198
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
197199

198-
for layer_name in layer_config:
199-
if (
200-
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
201-
): ##lm head # TODO fix act and so on
202-
extra_config[layer_name] = {}
203-
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
204-
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
205-
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
206-
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
207-
elif layer_config[layer_name]["in_blocks"] or (
200+
scheme_keys = [f.name for f in fields(QuantizationScheme)]
201+
for layer_name, cfg in layer_config.items():
202+
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
203+
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
204+
elif cfg["in_blocks"] or (
208205
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
209206
):
210-
neq_keys = check_neq_config(
211-
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
212-
)
207+
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
213208
if len(neq_keys) > 0:
214209
extra_config[layer_name] = {}
215-
for key in neq_keys:
216-
if layer_config[layer_name][key] is not None:
217-
extra_config[layer_name][key] = layer_config[layer_name][key]
210+
for key in scheme_keys:
211+
if cfg[key] is not None:
212+
extra_config[layer_name][key] = cfg[key]
213+
218214
if len(extra_config) > 0:
219215
quantization_config["extra_config"] = extra_config
220216
names = list(layer_config.keys())

auto_round/export/export_to_autoround/utils.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,30 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
REQUIRED_CONFIG_KEYS = (
16-
"data_type",
17-
"bits",
18-
"group_size",
19-
"sym",
20-
"act_bits",
21-
"act_data_type",
22-
"act_group_size",
23-
"act_sym",
24-
"act_dynamic",
25-
)
15+
from dataclasses import fields
16+
from typing import List
2617

18+
from auto_round.schemes import QuantizationScheme
2719

28-
def check_neq_config(config: dict, **expected) -> dict[str, tuple]:
20+
21+
def check_neq_config(config: dict, **expected) -> List[str]:
2922
"""
3023
Compare a config dict against expected values.
3124
Ensures all required keys are present in both config and expected.
3225
3326
Returns:
34-
dict[str, tuple]: {key: (actual, expected)} for mismatched values.
27+
List[str]: [keys] for mismatched values.
3528
"""
29+
scheme_keys = [f.name for f in fields(QuantizationScheme)]
3630
# 1. Check missing from expected
37-
missing_expected = [k for k in REQUIRED_CONFIG_KEYS if k not in expected]
31+
missing_expected = [k for k in scheme_keys if k not in expected]
3832
if missing_expected:
3933
raise ValueError(f"Missing expected values for keys: {missing_expected}")
4034

4135
# 2. Check missing from layer config
42-
missing_config = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
36+
missing_config = [k for k in scheme_keys if k not in config]
4337
if missing_config:
4438
raise ValueError(f"Missing config values for keys: {missing_config}")
4539

4640
# 3. Collect mismatches
47-
return {key: (config[key], expected[key]) for key in REQUIRED_CONFIG_KEYS if config[key] != expected[key]}
41+
return [key for key in scheme_keys if config[key] != expected[key] and config[key] is not None]

0 commit comments

Comments (0)