Commit be146c7

fix gguf scheme and device_map bug (#969)
1 parent 4afbe0a commit be146c7

5 files changed: +71 -27 lines changed

auto_round/alg_ext.abi3.so

360 KB
Binary file not shown.

auto_round/compressors/base.py

Lines changed: 34 additions & 16 deletions
@@ -80,6 +80,7 @@
     get_layer_names_in_block,
     get_module,
     htcore,
+    is_complex_device_mapping,
     is_debug_mode,
     is_fp8_linear,
     is_fp8_model,
@@ -472,7 +473,15 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw
         """Parse and set the quantization scheme."""
 
         def _parse_and_set(scheme, kwargs):
-            res = ""
+            if kwargs.get("data_type", None) and kwargs["data_type"].endswith("_dq") and not scheme.startswith("gguf"):
+                if "bits" not in kwargs:
+                    data_type = kwargs["data_type"]
+                    raise KeyError(
+                        f"please set bits when setting data_type={data_type}, or using scheme as an alternative.."
+                    )
+                bits = kwargs["bits"]
+                scheme = f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"
+            res = None
             if isinstance(scheme, QuantizationScheme):
                 scheme = asdict(scheme)
             elif isinstance(scheme, dict):
@@ -1205,7 +1214,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         m = get_module(self.model, name)
 
         if is_fp8_linear(m):
-            m = convert_fp8_layer_to_linear(m, self.amp_dtype)
+            m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
             set_module(self.model, name, m)
 
         # Step 1: Try quantization on GPU first, fall back to CPU if OOM
@@ -1358,7 +1367,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
                 cnt += 1
         # Convert remaining fp8
         if is_fp8_model(self.model):
-            convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
+            convert_fp8_model_to_16b_model(self.model, self.amp_dtype, self.device)
         self.quantized = True
         return self.model, self.layer_config
 
@@ -1424,16 +1433,15 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
         for block_name in block_names:
             pbar.set_description(f"Quantizing {block_name}")
             block = get_module(self.model, block_name)
-            block = block.to(self.device)
             if is_fp8_model(self.model):
-                convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
+                convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)
 
-            if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
+            if is_complex_device_mapping(self.device_map):
                 set_auto_device_map_for_block_with_tuning(
                     block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
                 )
             # Dispatch model if needed
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 from accelerate.hooks import AlignDevicesHook, add_hook_to_module
 
                 for _, m in block.named_modules():
@@ -1451,7 +1459,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 self.device,
                 self.cache_device,
             )
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)
 
             if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
@@ -1630,7 +1638,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         if is_fp8_model(self.model):
             for n, m in self.model.named_modules():
                 if is_fp8_linear(m):
-                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to("cpu")
+                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to("cpu")
                     set_module(self.model, n, new_layer)
 
         end_time = time.time()
@@ -1678,8 +1686,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
 
             layer = get_module(self.model, layer_name)
             layer = layer.to(self.device)
-            if is_fp8_model(self.model):
-                new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype).to(self.device)
+            if is_fp8_linear(layer):
+                new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype, self.device).to(self.device)
                 set_module(self.model, layer_name, new_layer)
                 layer = new_layer
 
@@ -2445,17 +2453,17 @@ def _quantize_block(
         if is_fp8_model(self.model):
             for n, m in block.named_modules():
                 if is_fp8_linear(m):
-                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
+                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
                     set_module(block, n, new_layer)
 
-        if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
+        if is_complex_device_mapping(self.device_map):
             set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
         else:
             block = block.to(device)
 
-        if self.device_map is not None:
+        if is_complex_device_mapping(self.device_map):
             for n, m in block.named_modules():
                 if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
                     continue
@@ -2653,15 +2661,15 @@ def _quantize_block(
                 device,
                 cache_device=self.cache_device,
             )
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)
 
             return q_outputs, output
 
         else:
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)
@@ -2741,6 +2749,16 @@ def _quantize_blocks(
             except (ImportError, ModuleNotFoundError):
                 logger.error("algorithm extension import error, fallback to default mode")
                 quantize_block = self._quantize_block
+        elif self.enable_alg_ext and self.data_type.endswith("dq"):
+            try:
+                from auto_round.alg_ext import dq_quantize_block_ext
+
+                BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
+                quantize_block = self.dq_quantize_block_ext
+                logger.info("using algorithm extension for quantization.")
+            except (ImportError, ModuleNotFoundError):
+                logger.error("algorithm extension import error, fallback to default mode")
+                quantize_block = self._quantize_block
         else:
             quantize_block = self._quantize_block
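
The _parse_and_set change above is the gguf-scheme half of the fix: when a double-quant data type (one ending in "_dq") is supplied without an explicit gguf scheme, the scheme is now derived from the requested bit width instead of falling through. A minimal standalone sketch of that mapping follows; the helper name and the example data_type values are illustrative, not library API.

def resolve_gguf_scheme(scheme: str, **kwargs) -> str:
    # Mirrors the fallback added in _parse_and_set: a "*_dq" data_type without an
    # explicit gguf scheme is mapped to a gguf preset chosen by the bit width.
    data_type = kwargs.get("data_type", None)
    if data_type and data_type.endswith("_dq") and not scheme.startswith("gguf"):
        if "bits" not in kwargs:
            raise KeyError(f"please set bits when setting data_type={data_type}")
        bits = kwargs["bits"]
        # q6_k has no "_s" variant; the other bit widths map to the "_s" preset.
        return f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"
    return scheme

print(resolve_gguf_scheme("W2A16", data_type="int_sym_dq", bits=2))  # gguf:q2_k_s
print(resolve_gguf_scheme("W6A16", data_type="int_sym_dq", bits=6))  # gguf:q6_k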

auto_round/utils/device.py

Lines changed: 17 additions & 8 deletions
@@ -355,6 +355,19 @@ def get_packing_device(device: str | torch.device | None = "auto") -> torch.devi
     raise TypeError(f"Unsupported device type: {type(device)} ({device})")
 
 
+def is_complex_device_mapping(device_map):
+    if device_map is None or isinstance(device_map, int):
+        return False
+    elif device_map == "auto":
+        return True
+    elif isinstance(device_map, str) and "," in device_map:
+        return True
+    elif isinstance(device_map, dict):
+        return True
+    else:
+        return False
+
+
 class CpuInfo(object):
     """Get CPU Info."""
 
@@ -598,15 +611,11 @@ def set_tuning_device_for_layer(model, name: str, device: str) -> None:
 def set_non_auto_device_map(
     model: torch.nn.Module, device_map: Union[str, int, dict], quant_layer_names: Union[None, list, tuple] = None
 ) -> None:
-    if not device_map:
-        return
-    if device_map == "auto":
-        return
-    if isinstance(device_map, str) and "," in device_map:  # auto device map
-        return
-    if isinstance(device_map, int):
+    if not device_map or device_map == "auto" or isinstance(device_map, int):
         return
     if isinstance(device_map, str):
+        if "," in device_map:  # auto device map
+            return
         device_map = device_map.replace(" ", "")
         infos = device_map.split(",")
         device_map_dict = {}
@@ -840,7 +849,7 @@ def set_auto_device_map_for_block_with_tuning(
         num_devices = torch.xpu.device_count()
         device_name = "xpu"
     else:
-        raise RuntimeError("No CUDA or XPU devices found.")
+        return
     device_list = None
     if isinstance(device_map, str) and "," in device_map:
         device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
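
For reference, the new is_complex_device_mapping predicate is what the base.py call sites above now share: only "auto", multi-device strings, and dict mappings trigger block-level dispatch and accelerate hooks, while None, integer indices, and plain single-device strings keep the simple path. Expected behaviour, assuming auto_round.utils.device is importable:

from auto_round.utils.device import is_complex_device_mapping

assert is_complex_device_mapping(None) is False           # no mapping configured
assert is_complex_device_mapping(0) is False              # single device index
assert is_complex_device_mapping("cuda:0") is False       # plain single-device string
assert is_complex_device_mapping("auto") is True          # automatic placement
assert is_complex_device_mapping("0,1") is True           # multi-device string
assert is_complex_device_mapping({"lm_head": 1}) is True  # explicit dict mapping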

auto_round/utils/model.py

Lines changed: 4 additions & 3 deletions
@@ -927,7 +927,7 @@ def check_seqlen_compatible(input_seqlen, tokenizer=None, model=None):
     )
 
 
-def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
+def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
     """ """
     from auto_round.schemes import QuantizationScheme
 
@@ -939,6 +939,7 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
     for key in keys:
         setattr(new_layer, key, getattr(layer, key, None))
 
+    layer = layer.to(device)
     if layer.__class__.__name__ == "CompressedLinear":
         dq_weight = layer.compressor.decompress_module(layer)
     else:
@@ -948,7 +949,7 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
     return new_layer
 
 
-def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
+def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16, device: str = "cpu"):
     """
     Convert a model with FP8 quantized layers to a model with 16-bit linear layers.
     This is useful for compatibility with other frameworks or for further processing.
@@ -958,7 +959,7 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
     cnt = 0
     for n, m in model.named_modules():
         if m.__class__.__name__ == "FP8Linear":
-            new_module = convert_fp8_layer_to_linear(m, dtype=dtype)
+            new_module = convert_fp8_layer_to_linear(m, dtype=dtype, device=device)
             set_module(model, n, new_module)
             cnt += 1
             if cnt % 10 == 0:  # Tricky setting
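
The behavioural change in model.py is threading a device argument through the FP8-to-16-bit conversion, so the source layer is moved to the caller's device before dequantization instead of being decompressed wherever it happened to be loaded. A self-contained sketch of that pattern with a stand-in layer class; FakeFP8Linear and convert_to_16b are illustrative, not the library's FP8Linear or convert_fp8_layer_to_linear.

import torch

class FakeFP8Linear(torch.nn.Module):
    # Stand-in for an FP8 linear layer; real FP8 storage is omitted for brevity.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)

def convert_to_16b(layer: torch.nn.Module, dtype=torch.bfloat16, device: str = "cpu") -> torch.nn.Linear:
    layer = layer.to(device)  # the fix: move the source layer to the target device first
    out_features, in_features = layer.weight.shape
    new_layer = torch.nn.Linear(in_features, out_features, bias=False, dtype=dtype, device=device)
    new_layer.weight.data.copy_(layer.weight.to(dtype))  # dequantize/copy on `device`
    return new_layer

layer16 = convert_to_16b(FakeFP8Linear(8, 4), device="cpu")
print(layer16.weight.dtype, layer16.weight.device)  # torch.bfloat16 cpu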

test/test_cuda/test_alg_ext.py

Lines changed: 16 additions & 0 deletions
@@ -38,3 +38,19 @@ def test_2bits(self):
         # wo alg ext 0.2084, with 0.2364
         self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.22)
         shutil.rmtree(self.save_folder, ignore_errors=True)
+
+    def test_cli(self):
+        import os
+
+        model_name = "/models/opt-125m"
+        python_path = sys.executable
+
+        res = os.system(
+            f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
+        )
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
+
+if __name__ == "__main__":
+    unittest.main()
