     get_layer_names_in_block,
     get_module,
     htcore,
+    is_complex_device_mapping,
     is_debug_mode,
     is_fp8_linear,
     is_fp8_model,
@@ -472,7 +473,15 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw
472473 """Parse and set the quantization scheme."""
473474
474475 def _parse_and_set (scheme , kwargs ):
475- res = ""
476+ if kwargs .get ("data_type" , None ) and kwargs ["data_type" ].endswith ("_dq" ) and not scheme .startswith ("gguf" ):
477+ if "bits" not in kwargs :
478+ data_type = kwargs ["data_type" ]
479+ raise KeyError (
480+ f"please set bits when setting data_type={ data_type } , or using scheme as an alternative.."
481+ )
482+ bits = kwargs ["bits" ]
483+ scheme = f"gguf:q{ bits } _k" if bits == 6 else f"gguf:q{ bits } _k_s"
484+ res = None
476485 if isinstance (scheme , QuantizationScheme ):
477486 scheme = asdict (scheme )
478487 elif isinstance (scheme , dict ):
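
The new branch in `_parse_and_set` rewrites a `*_dq` data type plus an explicit bit width into a GGUF k-quant scheme string before the usual scheme parsing runs. A minimal, self-contained sketch of that mapping is given below; the helper name and the example `data_type` value are illustrative only and not part of the codebase.

def _dq_scheme_from_kwargs(kwargs: dict, scheme: str = "") -> str:
    # Hypothetical standalone helper mirroring the branch added above.
    data_type = kwargs.get("data_type") or ""
    if data_type.endswith("_dq") and not scheme.startswith("gguf"):
        if "bits" not in kwargs:
            raise KeyError(f"please set bits when setting data_type={data_type}")
        bits = kwargs["bits"]
        # bits == 6 maps to the plain k-quant; other widths use the small (_s) variant
        return f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"
    return scheme

assert _dq_scheme_from_kwargs({"data_type": "int_sym_dq", "bits": 4}) == "gguf:q4_k_s"
assert _dq_scheme_from_kwargs({"data_type": "int_sym_dq", "bits": 6}) == "gguf:q6_k"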
@@ -1205,7 +1214,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         m = get_module(self.model, name)

         if is_fp8_linear(m):
-            m = convert_fp8_layer_to_linear(m, self.amp_dtype)
+            m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
             set_module(self.model, name, m)

         # Step 1: Try quantization on GPU first, fall back to CPU if OOM
@@ -1358,7 +1367,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             cnt += 1
         # Convert remaining fp8
         if is_fp8_model(self.model):
-            convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
+            convert_fp8_model_to_16b_model(self.model, self.amp_dtype, self.device)
         self.quantized = True
         return self.model, self.layer_config

@@ -1424,16 +1433,15 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
         for block_name in block_names:
             pbar.set_description(f"Quantizing {block_name}")
             block = get_module(self.model, block_name)
-            block = block.to(self.device)
             if is_fp8_model(self.model):
-                convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
+                convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)

-            if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
+            if is_complex_device_mapping(self.device_map):
                 set_auto_device_map_for_block_with_tuning(
                     block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
                 )
             # Dispatch model if needed
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 from accelerate.hooks import AlignDevicesHook, add_hook_to_module

                 for _, m in block.named_modules():
@@ -1451,7 +1459,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 self.device,
                 self.cache_device,
             )
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)

             if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
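
Throughout this diff the repeated inline checks on `self.device_map` are consolidated into the newly imported `is_complex_device_mapping` helper. A hedged sketch of the predicate such a helper presumably implements, inferred only from the conditions it replaces here ("auto", a comma-separated multi-device string, or an explicit per-module map); the real implementation lives in the project's utilities and may differ.

def is_complex_device_mapping(device_map) -> bool:
    # Sketch: treat None as a simple mapping, "auto" or "cuda:0,cuda:1"-style
    # strings as complex, and dict-style per-module maps as complex.
    if device_map is None:
        return False
    if isinstance(device_map, str):
        return device_map == "auto" or "," in device_map
    return isinstance(device_map, dict)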
@@ -1630,7 +1638,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         if is_fp8_model(self.model):
             for n, m in self.model.named_modules():
                 if is_fp8_linear(m):
-                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to("cpu")
+                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to("cpu")
                     set_module(self.model, n, new_layer)

         end_time = time.time()
@@ -1678,8 +1686,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

             layer = get_module(self.model, layer_name)
             layer = layer.to(self.device)
-            if is_fp8_model(self.model):
-                new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype).to(self.device)
+            if is_fp8_linear(layer):
+                new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype, self.device).to(self.device)
                 set_module(self.model, layer_name, new_layer)
                 layer = new_layer

@@ -2445,17 +2453,17 @@ def _quantize_block(
         if is_fp8_model(self.model):
             for n, m in block.named_modules():
                 if is_fp8_linear(m):
-                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
+                    new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
                     set_module(block, n, new_layer)

-        if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
+        if is_complex_device_mapping(self.device_map):
             set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
         else:
             block = block.to(device)

-        if self.device_map is not None:
+        if is_complex_device_mapping(self.device_map):
             for n, m in block.named_modules():
                 if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
                     continue
@@ -2653,15 +2661,15 @@ def _quantize_block(
                 device,
                 cache_device=self.cache_device,
             )
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)

             return q_outputs, output

         else:
-            if self.device_map is not None:
+            if is_complex_device_mapping(self.device_map):
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)
@@ -2741,6 +2749,16 @@ def _quantize_blocks(
             except (ImportError, ModuleNotFoundError):
                 logger.error("algorithm extension import error, fallback to default mode")
                 quantize_block = self._quantize_block
+        elif self.enable_alg_ext and self.data_type.endswith("dq"):
+            try:
+                from auto_round.alg_ext import dq_quantize_block_ext
+
+                BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
+                quantize_block = self.dq_quantize_block_ext
+                logger.info("using algorithm extension for quantization.")
+            except (ImportError, ModuleNotFoundError):
+                logger.error("algorithm extension import error, fallback to default mode")
+                quantize_block = self._quantize_block
         else:
             quantize_block = self._quantize_block

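The new `elif` branch mirrors the existing alg_ext dispatch: the dq extension is attached only when `enable_alg_ext` is set and the data type is a `dq` variant, and any import failure falls back to the default `_quantize_block`. Below is a condensed, hypothetical restatement of that selection logic; the method name is made up and the surrounding class attributes are assumed.

def _select_quantize_block(self):
    # Hypothetical condensation of the dispatch above; not a function in the codebase.
    if self.enable_alg_ext and self.data_type.endswith("dq"):
        try:
            from auto_round.alg_ext import dq_quantize_block_ext

            # bind the extension onto the class so it can be called as a bound method
            type(self).dq_quantize_block_ext = dq_quantize_block_ext
            return self.dq_quantize_block_ext
        except (ImportError, ModuleNotFoundError):
            pass  # extension unavailable: fall back to the default implementation
    return self._quantize_block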